from datasets import Dataset, load_dataset, concatenate_datasets
from distilabel.models import vLLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration
from distilabel.steps import (KeepColumns, FormatTextGenerationSFT)
import shutil
import os
import pandas as pd
import logging
import argparse
import json
import random
import time
import tqdm

#from tqdm import tqdm


# Remove any cached results from a previous run of this pipeline so every run
# regenerates from scratch. The path follows distilabel's default cache layout
# (<cache root>/<pipeline name>): the root is DISTILABEL_CACHE_DIR when that
# variable is set, otherwise ~/.cache/distilabel/pipelines. The previous
# hard-coded "/root/.cache/..." path only matched when running as root, so for
# any other user the stale cache was left in place.
# NOTE(review): confirm the cache-root lookup against the installed distilabel
# version. The last path component must match the Pipeline `name` used below.
_distilabel_cache_root = os.environ.get(
    "DISTILABEL_CACHE_DIR",
    os.path.join(os.path.expanduser("~"), ".cache", "distilabel", "pipelines"),
)
pipeline_cache = os.path.join(_distilabel_cache_root, "qwen-7b-1M-majority")
if os.path.exists(pipeline_cache):
    shutil.rmtree(pipeline_cache)

# Prompt template for TextGeneration: the pre-built `entire_instruction`
# column followed by a newline.
prompt_template = """{{entire_instruction}}
"""

# ToMi balanced test file: rows 800..999 of the single "train" split.
dataset_tomi = load_dataset("json", data_files=".../data/test_balanced.json", split="train[800:1000]")

def add_combined_column_tomi(dataset):
    """Add an ``entire_instruction`` prompt column to a ToMi *dataset*.

    The prompt has the form
    ``"Story: <story> Question: <question> Choices: <choices>"``. The choices
    come from the ``containers`` field: a list is joined with ", ", and any
    other value is interpolated as-is.

    Returns whatever ``dataset.map`` returns for the per-row formatter.
    """

    def _to_prompt(row):
        raw = row["containers"]
        choices = ", ".join(raw) if isinstance(raw, list) else raw
        story, question = row["story"], row["question"]
        row["entire_instruction"] = (
            f"Story: {story} Question: {question} Choices: {choices}"
        )
        return row

    return dataset.map(_to_prompt)

# Attach the ToMi prompt column, then show the dataset summary and one row.
dataset_tomi = add_combined_column_tomi(dataset_tomi)
print(dataset_tomi)
print(dataset_tomi[0])

# Hi-ToM: six 20-row windows taken from different regions of the train split,
# loaded in this order and then concatenated.
(
    dataset_test_hitom_1,
    dataset_test_hitom_2,
    dataset_test_hitom_3,
    dataset_test_hitom_4,
    dataset_test_hitom_5,
    dataset_test_hitom_6,
) = (
    load_dataset(".../ToM_data/Hi-ToM", split=window)
    for window in (
        "train[80:100]",
        "train[180:200]",
        "train[280:300]",
        "train[680:700]",
        "train[780:800]",
        "train[880:900]",
    )
)

dataset_hitom = concatenate_datasets(
    [
        dataset_test_hitom_1,
        dataset_test_hitom_2,
        dataset_test_hitom_3,
        dataset_test_hitom_4,
        dataset_test_hitom_5,
        dataset_test_hitom_6,
    ]
)

def add_combined_column_hitom(dataset):
    """Add an ``entire_instruction`` prompt column to a Hi-ToM *dataset*.

    The prompt has the form
    ``"Story: <story> Question: <question> Choices: <choices>"``. When
    ``choices`` is a list its items are joined with ", " (matching the ToMi
    formatter). Any other value, such as a pre-formatted string, is
    interpolated unchanged.

    Returns whatever ``dataset.map`` returns for the per-row formatter.
    """

    def combine_text(example):
        choices_text = example["choices"]
        # Without this, a list would be rendered as its Python repr
        # ("['a', 'b']") in the prompt, contrary to the intended formatting.
        if isinstance(choices_text, list):
            choices_text = ", ".join(map(str, choices_text))
        example["entire_instruction"] = (
            f"Story: {example['story']} Question: {example['question']} "
            f"Choices: {choices_text}"
        )
        return example

    return dataset.map(combine_text)

# Attach the Hi-ToM prompt column, then show the dataset summary and one row.
dataset_hitom = add_combined_column_hitom(dataset_hitom)
print(dataset_hitom)
print(dataset_hitom[0])


# ToMBench "Moral Emotions" file: first 40 rows of its "train" split.
dataset_tombench = load_dataset(
    "json",
    data_files=".../ToMbench_data/Moral Emotions.json",
    split="train[:40]",
)
def add_combined_column_tombench(dataset):
    """Add an ``entire_instruction`` prompt column to a ToMBench *dataset*.

    The prompt has the form
    ``"Story: <STORY> Question: <QUESTION> Choices: A: .. B: ..[ C: ..][ D: ..]"``.
    ``OPTION-A`` and ``OPTION-B`` are always included; ``OPTION-C`` and
    ``OPTION-D`` are appended only when they are not ``None``.

    The earlier version also resolved the gold answer text into a local that
    was never used. That dead lookup has been removed, so rows no longer need
    the answer column for this step; this function does not modify it.

    Returns whatever ``dataset.map`` returns for the per-row formatter.
    """

    def combine_text(example):
        # A and B are mandatory. C and D are optional (presumably None for
        # questions with fewer options; verify against the data). "+"
        # concatenation keeps failing loudly if A or B is not a string.
        parts = ["A: " + example["OPTION-A"], "B: " + example["OPTION-B"]]
        for letter in ("C", "D"):
            option = example["OPTION-" + letter]
            if option is not None:
                parts.append(letter + ": " + option)
        formatted_string = " ".join(parts)

        example["entire_instruction"] = (
            f"Story: {example['STORY']} Question: {example['QUESTION']} "
            f"Choices: {formatted_string}"
        )
        return example

    return dataset.map(combine_text)

# Attach the ToMBench prompt column, then show the dataset summary and one row.
dataset_tombench = add_combined_column_tombench(dataset_tombench)
print(dataset_tombench)
print(dataset_tombench[0])

# ToMATO "tomato_second.json" file: first 25 rows of its "train" split.
dataset_tomato = load_dataset(
    "json",
    data_files=".../ToMATO/dataset/tomato_second.json",
    split="train[:25]",
)
def add_combined_column_tomato(dataset):
    """Add an ``entire_instruction`` prompt column to a ToMATO *dataset*.

    The prompt has the form
    ``"Conversation: <conversation> Question: <q> Choices: A: <a0> B: <a1> C: <a2> D: <a3>"``,
    i.e. the answer fields ``a0``..``a3`` are labelled A..D in that order.

    Returns whatever ``dataset.map`` returns for the per-row formatter.
    """

    def _render(row):
        # Read all four options up front, then label them A-D.
        options = [row[key] for key in ("a0", "a1", "a2", "a3")]
        choices = " ".join(
            letter + ": " + option for letter, option in zip("ABCD", options)
        )
        row["entire_instruction"] = (
            f"Conversation: {row['conversation']} Question: {row['q']} "
            f"Choices: {choices}"
        )
        return row

    return dataset.map(_render)

# Attach the ToMATO prompt column, then show the dataset summary and one row.
dataset_tomato = add_combined_column_tomato(dataset_tomato)
print(dataset_tomato)
print(dataset_tomato[0])


# Only the ToMBench subset is passed to the pipeline below. The ToMi, Hi-ToM
# and ToMATO datasets prepared above are not used by this run.
entire_dataset = dataset_tombench




# Local model checkpoint path (placeholder); used for both weights and tokenizer.
model_id = ".../Qwen2.5-7B-Instruct-1M"

# Build the generation pipeline. The LLM and TextGeneration step are created
# inside this `with` block, presumably so distilabel registers them on
# `pipeline`; keep their construction inside the context.
with Pipeline(
    # NOTE: must stay in sync with the cache directory wiped at the top of the file.
    name="qwen-7b-1M-majority",
    description="A pipeline to generate data from a qwen model",
) as pipeline:

    # vLLM-backed LLM with no tensor parallelism and an 8192-token context window.
    llm = vLLM(
        model=model_id,
        tokenizer=model_id,
        extra_kwargs={
            "tensor_parallel_size": 1,
            "max_model_len": 8192,
        },
        generation_kwargs={
            "temperature": 0.7,
            # NOTE(review): equals max_model_len, so prompt + completion can
            # exceed the context window; confirm how the installed vLLM
            # handles this (truncation vs. error).
            "max_new_tokens": 8192,
        },
    )


    # Samples 8 completions per prompt (presumably for majority voting, per the
    # pipeline name), rendering `prompt_template` from the
    # `entire_instruction` column; rows are sent in batches of 4.
    text_generation = TextGeneration(
        llm=llm,
        template=prompt_template,
        num_generations=8,
        input_batch_size=4,
        columns = ["entire_instruction"],
    )

    



if __name__ == "__main__":
    # Run generation over the selected dataset, show a summary and the first
    # generated row, then persist the result to disk.
    distiset = pipeline.run(dataset=entire_dataset)
    print(distiset)
    sample_row = distiset["default"]["train"][0]
    print(sample_row)
    distiset.save_to_disk(".../SFTData/entire_majority")
    